{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 640, 32, 32])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "class Resnet(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, dim_in, dim_out):\n",
    "        super().__init__()\n",
    "\n",
    "        self.time = torch.nn.Sequential(\n",
    "            torch.nn.SiLU(),\n",
    "            torch.torch.nn.Linear(1280, dim_out),\n",
    "            torch.nn.Unflatten(dim=1, unflattened_size=(dim_out, 1, 1)),\n",
    "        )\n",
    "\n",
    "        self.s0 = torch.nn.Sequential(\n",
    "            torch.torch.nn.GroupNorm(num_groups=32,\n",
    "                                     num_channels=dim_in,\n",
    "                                     eps=1e-05,\n",
    "                                     affine=True),\n",
    "            torch.nn.SiLU(),\n",
    "            torch.torch.nn.Conv2d(dim_in,\n",
    "                                  dim_out,\n",
    "                                  kernel_size=3,\n",
    "                                  stride=1,\n",
    "                                  padding=1),\n",
    "        )\n",
    "\n",
    "        self.s1 = torch.nn.Sequential(\n",
    "            torch.torch.nn.GroupNorm(num_groups=32,\n",
    "                                     num_channels=dim_out,\n",
    "                                     eps=1e-05,\n",
    "                                     affine=True),\n",
    "            torch.nn.SiLU(),\n",
    "            torch.torch.nn.Conv2d(dim_out,\n",
    "                                  dim_out,\n",
    "                                  kernel_size=3,\n",
    "                                  stride=1,\n",
    "                                  padding=1),\n",
    "        )\n",
    "\n",
    "        self.res = None\n",
    "        if dim_in != dim_out:\n",
    "            self.res = torch.torch.nn.Conv2d(dim_in,\n",
    "                                             dim_out,\n",
    "                                             kernel_size=1,\n",
    "                                             stride=1,\n",
    "                                             padding=0)\n",
    "\n",
    "    def forward(self, x, time):\n",
    "        #x -> [1, 320, 64, 64]\n",
    "        #time -> [1, 1280]\n",
    "\n",
    "        res = x\n",
    "\n",
    "        #[1, 1280] -> [1, 640, 1, 1]\n",
    "        time = self.time(time)\n",
    "\n",
    "        #[1, 320, 64, 64] -> [1, 640, 32, 32]\n",
    "        x = self.s0(x) + time\n",
    "\n",
    "        #维度不变\n",
    "        #[1, 640, 32, 32]\n",
    "        x = self.s1(x)\n",
    "\n",
    "        #[1, 320, 64, 64] -> [1, 640, 32, 32]\n",
    "        if self.res:\n",
    "            res = self.res(res)\n",
    "\n",
    "        #维度不变\n",
    "        #[1, 640, 32, 32]\n",
    "        x = res + x\n",
    "\n",
    "        return x\n",
    "\n",
    "\n",
    "Resnet(320, 640)(torch.randn(1, 320, 32, 32), torch.randn(1, 1280)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 4096, 320])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class CrossAttention(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, dim_q, dim_kv):\n",
    "        #dim_q -> 320\n",
    "        #dim_kv -> 768\n",
    "\n",
    "        super().__init__()\n",
    "\n",
    "        self.dim_q = dim_q\n",
    "\n",
    "        self.q = torch.nn.Linear(dim_q, dim_q, bias=False)\n",
    "        self.k = torch.nn.Linear(dim_kv, dim_q, bias=False)\n",
    "        self.v = torch.nn.Linear(dim_kv, dim_q, bias=False)\n",
    "\n",
    "        self.out = torch.nn.Linear(dim_q, dim_q)\n",
    "\n",
    "    def forward(self, q, kv):\n",
    "        #x -> [1, 4096, 320]\n",
    "        #kv -> [1, 77, 768]\n",
    "\n",
    "        #[1, 4096, 320] -> [1, 4096, 320]\n",
    "        q = self.q(q)\n",
    "        #[1, 77, 768] -> [1, 77, 320]\n",
    "        k = self.k(kv)\n",
    "        #[1, 77, 768] -> [1, 77, 320]\n",
    "        v = self.v(kv)\n",
    "\n",
    "        def reshape(x):\n",
    "            #x -> [1, 4096, 320]\n",
    "            b, lens, dim = x.shape\n",
    "\n",
    "            #[1, 4096, 320] -> [1, 4096, 8, 40]\n",
    "            x = x.reshape(b, lens, 8, dim // 8)\n",
    "\n",
    "            #[1, 4096, 8, 40] -> [1, 8, 4096, 40]\n",
    "            x = x.transpose(1, 2)\n",
    "\n",
    "            #[1, 8, 4096, 40] -> [8, 4096, 40]\n",
    "            x = x.reshape(b * 8, lens, dim // 8)\n",
    "\n",
    "            return x\n",
    "\n",
    "        #[1, 4096, 320] -> [8, 4096, 40]\n",
    "        q = reshape(q)\n",
    "        #[1, 77, 320] -> [8, 77, 40]\n",
    "        k = reshape(k)\n",
    "        #[1, 77, 320] -> [8, 77, 40]\n",
    "        v = reshape(v)\n",
    "\n",
    "        #[8, 4096, 40] * [8, 40, 77] -> [8, 4096, 77]\n",
    "        #atten = q.bmm(k.transpose(1, 2)) * (self.dim_q // 8)**-0.5\n",
    "\n",
    "        #从数学上是等价的,但是在实际计算时会产生很小的误差\n",
    "        atten = torch.baddbmm(\n",
    "            torch.empty(q.shape[0], q.shape[1], k.shape[1], device=q.device),\n",
    "            q,\n",
    "            k.transpose(1, 2),\n",
    "            beta=0,\n",
    "            alpha=(self.dim_q // 8)**-0.5,\n",
    "        )\n",
    "\n",
    "        atten = atten.softmax(dim=-1)\n",
    "\n",
    "        #[8, 4096, 77] * [8, 77, 40] -> [8, 4096, 40]\n",
    "        atten = atten.bmm(v)\n",
    "\n",
    "        def reshape(x):\n",
    "            #x -> [8, 4096, 40]\n",
    "            b, lens, dim = x.shape\n",
    "\n",
    "            #[8, 4096, 40] -> [1, 8, 4096, 40]\n",
    "            x = x.reshape(b // 8, 8, lens, dim)\n",
    "\n",
    "            #[1, 8, 4096, 40] -> [1, 4096, 8, 40]\n",
    "            x = x.transpose(1, 2)\n",
    "\n",
    "            #[1, 4096, 320]\n",
    "            x = x.reshape(b // 8, lens, dim * 8)\n",
    "\n",
    "            return x\n",
    "\n",
    "        #[8, 4096, 40] -> [1, 4096, 320]\n",
    "        atten = reshape(atten)\n",
    "\n",
    "        #[1, 4096, 320] -> [1, 4096, 320]\n",
    "        atten = self.out(atten)\n",
    "\n",
    "        return atten\n",
    "\n",
    "\n",
    "CrossAttention(320, 768)(torch.randn(1, 4096, 320), torch.randn(1, 77,\n",
    "                                                                768)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 320, 64, 64])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Transformer(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, dim):\n",
    "        super().__init__()\n",
    "\n",
    "        self.dim = dim\n",
    "\n",
    "        #in\n",
    "        self.norm_in = torch.nn.GroupNorm(num_groups=32,\n",
    "                                          num_channels=dim,\n",
    "                                          eps=1e-6,\n",
    "                                          affine=True)\n",
    "        self.cnn_in = torch.nn.Conv2d(dim,\n",
    "                                      dim,\n",
    "                                      kernel_size=1,\n",
    "                                      stride=1,\n",
    "                                      padding=0)\n",
    "\n",
    "        #atten\n",
    "        self.norm_atten0 = torch.nn.LayerNorm(dim, elementwise_affine=True)\n",
    "        self.atten1 = CrossAttention(dim, dim)\n",
    "        self.norm_atten1 = torch.nn.LayerNorm(dim, elementwise_affine=True)\n",
    "        self.atten2 = CrossAttention(dim, 768)\n",
    "\n",
    "        #act\n",
    "        self.norm_act = torch.nn.LayerNorm(dim, elementwise_affine=True)\n",
    "        self.fc0 = torch.nn.Linear(dim, dim * 8)\n",
    "        self.act = torch.nn.GELU()\n",
    "        self.fc1 = torch.nn.Linear(dim * 4, dim)\n",
    "\n",
    "        #out\n",
    "        self.cnn_out = torch.nn.Conv2d(dim,\n",
    "                                       dim,\n",
    "                                       kernel_size=1,\n",
    "                                       stride=1,\n",
    "                                       padding=0)\n",
    "\n",
    "    def forward(self, q, kv):\n",
    "        #q -> [1, 320, 64, 64]\n",
    "        #kv -> [1, 77, 768]\n",
    "        b, _, h, w = q.shape\n",
    "        res1 = q\n",
    "\n",
    "        #----in----\n",
    "        #维度不变\n",
    "        #[1, 320, 64, 64]\n",
    "        q = self.cnn_in(self.norm_in(q))\n",
    "\n",
    "        #[1, 320, 64, 64] -> [1, 64, 64, 320] -> [1, 4096, 320]\n",
    "        q = q.permute(0, 2, 3, 1).reshape(b, h * w, self.dim)\n",
    "\n",
    "        #----atten----\n",
    "        #维度不变\n",
    "        #[1, 4096, 320]\n",
    "        q = self.atten1(q=self.norm_atten0(q), kv=self.norm_atten0(q)) + q\n",
    "        q = self.atten2(q=self.norm_atten1(q), kv=kv) + q\n",
    "\n",
    "        #----act----\n",
    "        #[1, 4096, 320]\n",
    "        res2 = q\n",
    "\n",
    "        #[1, 4096, 320] -> [1, 4096, 2560]\n",
    "        q = self.fc0(self.norm_act(q))\n",
    "\n",
    "        #1280\n",
    "        d = q.shape[2] // 2\n",
    "\n",
    "        #[1, 4096, 1280] * [1, 4096, 1280] -> [1, 4096, 1280]\n",
    "        q = q[:, :, :d] * self.act(q[:, :, d:])\n",
    "\n",
    "        #[1, 4096, 1280] -> [1, 4096, 320]\n",
    "        q = self.fc1(q) + res2\n",
    "\n",
    "        #----out----\n",
    "        #[1, 4096, 320] -> [1, 64, 64, 320] -> [1, 320, 64, 64]\n",
    "        q = q.reshape(b, h, w, self.dim).permute(0, 3, 1, 2).contiguous()\n",
    "\n",
    "        #维度不变\n",
    "        #[1, 320, 64, 64]\n",
    "        q = self.cnn_out(q) + res1\n",
    "\n",
    "        return q\n",
    "\n",
    "\n",
    "Transformer(320)(torch.randn(1, 320, 64, 64), torch.randn(1, 77, 768)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 640, 16, 16])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class DownBlock(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, dim_in, dim_out):\n",
    "        super().__init__()\n",
    "\n",
    "        self.tf0 = Transformer(dim_out)\n",
    "        self.res0 = Resnet(dim_in, dim_out)\n",
    "\n",
    "        self.tf1 = Transformer(dim_out)\n",
    "        self.res1 = Resnet(dim_out, dim_out)\n",
    "\n",
    "        self.out = torch.nn.Conv2d(dim_out,\n",
    "                                   dim_out,\n",
    "                                   kernel_size=3,\n",
    "                                   stride=2,\n",
    "                                   padding=1)\n",
    "\n",
    "    def forward(self, out_vae, out_encoder, time):\n",
    "        outs = []\n",
    "\n",
    "        out_vae = self.res0(out_vae, time)\n",
    "        out_vae = self.tf0(out_vae, out_encoder)\n",
    "        outs.append(out_vae)\n",
    "\n",
    "        out_vae = self.res1(out_vae, time)\n",
    "        out_vae = self.tf1(out_vae, out_encoder)\n",
    "        outs.append(out_vae)\n",
    "\n",
    "        out_vae = self.out(out_vae)\n",
    "        outs.append(out_vae)\n",
    "\n",
    "        return out_vae, outs\n",
    "\n",
    "\n",
    "DownBlock(320, 640)(torch.randn(1, 320, 32, 32), torch.randn(1, 77, 768),\n",
    "                    torch.randn(1, 1280))[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 640, 64, 64])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class UpBlock(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, dim_in, dim_out, dim_prev, add_up):\n",
    "        super().__init__()\n",
    "\n",
    "        self.res0 = Resnet(dim_out + dim_prev, dim_out)\n",
    "        self.res1 = Resnet(dim_out + dim_out, dim_out)\n",
    "        self.res2 = Resnet(dim_in + dim_out, dim_out)\n",
    "\n",
    "        self.tf0 = Transformer(dim_out)\n",
    "        self.tf1 = Transformer(dim_out)\n",
    "        self.tf2 = Transformer(dim_out)\n",
    "\n",
    "        self.out = None\n",
    "        if add_up:\n",
    "            self.out = torch.nn.Sequential(\n",
    "                torch.nn.Upsample(scale_factor=2, mode='nearest'),\n",
    "                torch.nn.Conv2d(dim_out, dim_out, kernel_size=3, padding=1),\n",
    "            )\n",
    "\n",
    "    def forward(self, out_vae, out_encoder, time, out_down):\n",
    "        out_vae = self.res0(torch.cat([out_vae, out_down.pop()], dim=1), time)\n",
    "        out_vae = self.tf0(out_vae, out_encoder)\n",
    "\n",
    "        out_vae = self.res1(torch.cat([out_vae, out_down.pop()], dim=1), time)\n",
    "        out_vae = self.tf1(out_vae, out_encoder)\n",
    "\n",
    "        out_vae = self.res2(torch.cat([out_vae, out_down.pop()], dim=1), time)\n",
    "        out_vae = self.tf2(out_vae, out_encoder)\n",
    "\n",
    "        if self.out:\n",
    "            out_vae = self.out(out_vae)\n",
    "\n",
    "        return out_vae\n",
    "\n",
    "\n",
    "UpBlock(320, 640, 1280, True)(torch.randn(1, 1280, 32, 32),\n",
    "                              torch.randn(1, 77, 768), torch.randn(1, 1280), [\n",
    "                                  torch.randn(1, 320, 32, 32),\n",
    "                                  torch.randn(1, 640, 32, 32),\n",
    "                                  torch.randn(1, 640, 32, 32)\n",
    "                              ]).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 4, 64, 64])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class UNet(torch.nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        #in\n",
    "        self.in_vae = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)\n",
    "\n",
    "        self.in_time = torch.nn.Sequential(\n",
    "            torch.nn.Linear(320, 1280),\n",
    "            torch.nn.SiLU(),\n",
    "            torch.nn.Linear(1280, 1280),\n",
    "        )\n",
    "\n",
    "        #down\n",
    "        self.down_block0 = DownBlock(320, 320)\n",
    "        self.down_block1 = DownBlock(320, 640)\n",
    "        self.down_block2 = DownBlock(640, 1280)\n",
    "\n",
    "        self.down_res0 = Resnet(1280, 1280)\n",
    "        self.down_res1 = Resnet(1280, 1280)\n",
    "\n",
    "        #mid\n",
    "        self.mid_res0 = Resnet(1280, 1280)\n",
    "        self.mid_tf = Transformer(1280)\n",
    "        self.mid_res1 = Resnet(1280, 1280)\n",
    "\n",
    "        #up\n",
    "        self.up_res0 = Resnet(2560, 1280)\n",
    "        self.up_res1 = Resnet(2560, 1280)\n",
    "        self.up_res2 = Resnet(2560, 1280)\n",
    "\n",
    "        self.up_in = torch.nn.Sequential(\n",
    "            torch.nn.Upsample(scale_factor=2, mode='nearest'),\n",
    "            torch.nn.Conv2d(1280, 1280, kernel_size=3, padding=1),\n",
    "        )\n",
    "\n",
    "        self.up_block0 = UpBlock(640, 1280, 1280, True)\n",
    "        self.up_block1 = UpBlock(320, 640, 1280, True)\n",
    "        self.up_block2 = UpBlock(320, 320, 640, False)\n",
    "\n",
    "        #out\n",
    "        self.out = torch.nn.Sequential(\n",
    "            torch.nn.GroupNorm(num_channels=320, num_groups=32, eps=1e-5),\n",
    "            torch.nn.SiLU(),\n",
    "            torch.nn.Conv2d(320, 4, kernel_size=3, padding=1),\n",
    "        )\n",
    "\n",
    "    def forward(self, out_vae, out_encoder, time):\n",
    "        #out_vae -> [1, 4, 64, 64]\n",
    "        #out_encoder -> [1, 77, 768]\n",
    "        #time -> [1]\n",
    "\n",
    "        #----in----\n",
    "        #[1, 4, 64, 64] -> [1, 320, 64, 64]\n",
    "        out_vae = self.in_vae(out_vae)\n",
    "\n",
    "        def get_time_embed(t):\n",
    "            #-9.210340371976184 = -math.log(10000)\n",
    "            e = torch.arange(160) * -9.210340371976184 / 160\n",
    "            e = e.exp().to(t.device) * t\n",
    "\n",
    "            #[160+160] -> [320] -> [1, 320]\n",
    "            e = torch.cat([e.cos(), e.sin()]).unsqueeze(dim=0)\n",
    "\n",
    "            return e\n",
    "\n",
    "        #[1] -> [1, 320]\n",
    "        time = get_time_embed(time)\n",
    "        #[1, 320] -> [1, 1280]\n",
    "        time = self.in_time(time)\n",
    "\n",
    "        #----down----\n",
    "        #[1, 320, 64, 64]\n",
    "        #[1, 320, 64, 64]\n",
    "        #[1, 320, 64, 64]\n",
    "        #[1, 320, 32, 32]\n",
    "        #[1, 640, 32, 32]\n",
    "        #[1, 640, 32, 32]\n",
    "        #[1, 640, 16, 16]\n",
    "        #[1, 1280, 16, 16]\n",
    "        #[1, 1280, 16, 16]\n",
    "        #[1, 1280, 8, 8]\n",
    "        #[1, 1280, 8, 8]\n",
    "        #[1, 1280, 8, 8]\n",
    "        out_down = [out_vae]\n",
    "\n",
    "        #[1, 320, 64, 64],[1, 77, 768],[1, 1280] -> [1, 320, 32, 32]\n",
    "        #out -> [1, 320, 64, 64],[1, 320, 64, 64][1, 320, 32, 32]\n",
    "        out_vae, out = self.down_block0(out_vae=out_vae,\n",
    "                                        out_encoder=out_encoder,\n",
    "                                        time=time)\n",
    "        out_down.extend(out)\n",
    "\n",
    "        #[1, 320, 32, 32],[1, 77, 768],[1, 1280] -> [1, 640, 16, 16]\n",
    "        #out -> [1, 640, 32, 32],[1, 640, 32, 32],[1, 640, 16, 16]\n",
    "        out_vae, out = self.down_block1(out_vae=out_vae,\n",
    "                                        out_encoder=out_encoder,\n",
    "                                        time=time)\n",
    "        out_down.extend(out)\n",
    "\n",
    "        #[1, 640, 16, 16],[1, 77, 768],[1, 1280] -> [1, 1280, 8, 8]\n",
    "        #out -> [1, 1280, 16, 16],[1, 1280, 16, 16],[1, 1280, 8, 8]\n",
    "        out_vae, out = self.down_block2(out_vae=out_vae,\n",
    "                                        out_encoder=out_encoder,\n",
    "                                        time=time)\n",
    "        out_down.extend(out)\n",
    "\n",
    "        #[1, 1280, 8, 8],[1, 1280] -> [1, 1280, 8, 8]\n",
    "        out_vae = self.down_res0(out_vae, time)\n",
    "        out_down.append(out_vae)\n",
    "\n",
    "        #[1, 1280, 8, 8],[1, 1280] -> [1, 1280, 8, 8]\n",
    "        out_vae = self.down_res1(out_vae, time)\n",
    "        out_down.append(out_vae)\n",
    "\n",
    "        #----mid----\n",
    "        #[1, 1280, 8, 8],[1, 1280] -> [1, 1280, 8, 8]\n",
    "        out_vae = self.mid_res0(out_vae, time)\n",
    "\n",
    "        #[1, 1280, 8, 8],[1, 77, 768] -> [1, 1280, 8, 8]\n",
    "        out_vae = self.mid_tf(out_vae, out_encoder)\n",
    "\n",
    "        #[1, 1280, 8, 8],[1, 1280] -> [1, 1280, 8, 8]\n",
    "        out_vae = self.mid_res1(out_vae, time)\n",
    "\n",
    "        #----up----\n",
    "        #[1, 1280+1280, 8, 8],[1, 1280] -> [1, 1280, 8, 8]\n",
    "        out_vae = self.up_res0(torch.cat([out_vae, out_down.pop()], dim=1),\n",
    "                               time)\n",
    "\n",
    "        #[1, 1280+1280, 8, 8],[1, 1280] -> [1, 1280, 8, 8]\n",
    "        out_vae = self.up_res1(torch.cat([out_vae, out_down.pop()], dim=1),\n",
    "                               time)\n",
    "\n",
    "        #[1, 1280+1280, 8, 8],[1, 1280] -> [1, 1280, 8, 8]\n",
    "        out_vae = self.up_res2(torch.cat([out_vae, out_down.pop()], dim=1),\n",
    "                               time)\n",
    "\n",
    "        #[1, 1280, 8, 8] -> [1, 1280, 16, 16]\n",
    "        out_vae = self.up_in(out_vae)\n",
    "\n",
    "        #[1, 1280, 16, 16],[1, 77, 768],[1, 1280] -> [1, 1280, 32, 32]\n",
    "        #out_down -> [1, 640, 16, 16],[1, 1280, 16, 16],[1, 1280, 16, 16]\n",
    "        out_vae = self.up_block0(out_vae=out_vae,\n",
    "                                 out_encoder=out_encoder,\n",
    "                                 time=time,\n",
    "                                 out_down=out_down)\n",
    "\n",
    "        #[1, 1280, 32, 32],[1, 77, 768],[1, 1280] -> [1, 640, 64, 64]\n",
    "        #out_down -> [1, 320, 32, 32],[1, 640, 32, 32],[1, 640, 32, 32]\n",
    "        out_vae = self.up_block1(out_vae=out_vae,\n",
    "                                 out_encoder=out_encoder,\n",
    "                                 time=time,\n",
    "                                 out_down=out_down)\n",
    "\n",
    "        #[1, 640, 64, 64],[1, 77, 768],[1, 1280] -> [1, 320, 64, 64]\n",
    "        #out_down -> [1, 320, 64, 64],[1, 320, 64, 64],[1, 320, 64, 64]\n",
    "        out_vae = self.up_block2(out_vae=out_vae,\n",
    "                                 out_encoder=out_encoder,\n",
    "                                 time=time,\n",
    "                                 out_down=out_down)\n",
    "\n",
    "        #----out----\n",
    "        #[1, 320, 64, 64] -> [1, 4, 64, 64]\n",
    "        out_vae = self.out(out_vae)\n",
    "\n",
    "        return out_vae\n",
    "\n",
    "\n",
    "UNet()(torch.randn(2, 4, 64, 64), torch.randn(2, 77, 768),\n",
    "       torch.LongTensor([26])).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from diffusers import UNet2DConditionModel\n",
    "\n",
    "#加载预训练模型的参数\n",
    "params = UNet2DConditionModel.from_pretrained(\n",
    "    'models/diffsion_from_scratch.params', subfolder='unet')\n",
    "\n",
    "unet = UNet()\n",
    "\n",
    "#in\n",
    "unet.in_vae.load_state_dict(params.conv_in.state_dict())\n",
    "unet.in_time[0].load_state_dict(params.time_embedding.linear_1.state_dict())\n",
    "unet.in_time[2].load_state_dict(params.time_embedding.linear_2.state_dict())\n",
    "\n",
    "\n",
    "#down\n",
    "def load_tf(model, param):\n",
    "    model.norm_in.load_state_dict(param.norm.state_dict())\n",
    "    model.cnn_in.load_state_dict(param.proj_in.state_dict())\n",
    "\n",
    "    model.atten1.q.load_state_dict(\n",
    "        param.transformer_blocks[0].attn1.to_q.state_dict())\n",
    "    model.atten1.k.load_state_dict(\n",
    "        param.transformer_blocks[0].attn1.to_k.state_dict())\n",
    "    model.atten1.v.load_state_dict(\n",
    "        param.transformer_blocks[0].attn1.to_v.state_dict())\n",
    "    model.atten1.out.load_state_dict(\n",
    "        param.transformer_blocks[0].attn1.to_out[0].state_dict())\n",
    "\n",
    "    model.atten2.q.load_state_dict(\n",
    "        param.transformer_blocks[0].attn2.to_q.state_dict())\n",
    "    model.atten2.k.load_state_dict(\n",
    "        param.transformer_blocks[0].attn2.to_k.state_dict())\n",
    "    model.atten2.v.load_state_dict(\n",
    "        param.transformer_blocks[0].attn2.to_v.state_dict())\n",
    "    model.atten2.out.load_state_dict(\n",
    "        param.transformer_blocks[0].attn2.to_out[0].state_dict())\n",
    "\n",
    "    model.fc0.load_state_dict(\n",
    "        param.transformer_blocks[0].ff.net[0].proj.state_dict())\n",
    "\n",
    "    model.fc1.load_state_dict(\n",
    "        param.transformer_blocks[0].ff.net[2].state_dict())\n",
    "\n",
    "    model.norm_atten0.load_state_dict(\n",
    "        param.transformer_blocks[0].norm1.state_dict())\n",
    "    model.norm_atten1.load_state_dict(\n",
    "        param.transformer_blocks[0].norm2.state_dict())\n",
    "    model.norm_act.load_state_dict(\n",
    "        param.transformer_blocks[0].norm3.state_dict())\n",
    "\n",
    "    model.cnn_out.load_state_dict(param.proj_out.state_dict())\n",
    "\n",
    "\n",
    "def load_res(model, param):\n",
    "    model.time[1].load_state_dict(param.time_emb_proj.state_dict())\n",
    "\n",
    "    model.s0[0].load_state_dict(param.norm1.state_dict())\n",
    "    model.s0[2].load_state_dict(param.conv1.state_dict())\n",
    "\n",
    "    model.s1[0].load_state_dict(param.norm2.state_dict())\n",
    "    model.s1[2].load_state_dict(param.conv2.state_dict())\n",
    "\n",
    "    if isinstance(model.res, torch.nn.Module):\n",
    "        model.res.load_state_dict(param.conv_shortcut.state_dict())\n",
    "\n",
    "\n",
    "def load_down_block(model, param):\n",
    "    load_tf(model.tf0, param.attentions[0])\n",
    "    load_tf(model.tf1, param.attentions[1])\n",
    "\n",
    "    load_res(model.res0, param.resnets[0])\n",
    "    load_res(model.res1, param.resnets[1])\n",
    "\n",
    "    model.out.load_state_dict(param.downsamplers[0].conv.state_dict())\n",
    "\n",
    "\n",
    "load_down_block(unet.down_block0, params.down_blocks[0])\n",
    "load_down_block(unet.down_block1, params.down_blocks[1])\n",
    "load_down_block(unet.down_block2, params.down_blocks[2])\n",
    "\n",
    "load_res(unet.down_res0, params.down_blocks[3].resnets[0])\n",
    "load_res(unet.down_res1, params.down_blocks[3].resnets[1])\n",
    "\n",
    "#mid\n",
    "load_tf(unet.mid_tf, params.mid_block.attentions[0])\n",
    "load_res(unet.mid_res0, params.mid_block.resnets[0])\n",
    "load_res(unet.mid_res1, params.mid_block.resnets[1])\n",
    "\n",
    "#up\n",
    "load_res(unet.up_res0, params.up_blocks[0].resnets[0])\n",
    "load_res(unet.up_res1, params.up_blocks[0].resnets[1])\n",
    "load_res(unet.up_res2, params.up_blocks[0].resnets[2])\n",
    "unet.up_in[1].load_state_dict(\n",
    "    params.up_blocks[0].upsamplers[0].conv.state_dict())\n",
    "\n",
    "\n",
    "def load_up_block(model, param):\n",
    "    load_tf(model.tf0, param.attentions[0])\n",
    "    load_tf(model.tf1, param.attentions[1])\n",
    "    load_tf(model.tf2, param.attentions[2])\n",
    "\n",
    "    load_res(model.res0, param.resnets[0])\n",
    "    load_res(model.res1, param.resnets[1])\n",
    "    load_res(model.res2, param.resnets[2])\n",
    "\n",
    "    if isinstance(model.out, torch.nn.Module):\n",
    "        model.out[1].load_state_dict(param.upsamplers[0].conv.state_dict())\n",
    "\n",
    "\n",
    "load_up_block(unet.up_block0, params.up_blocks[1])\n",
    "load_up_block(unet.up_block1, params.up_blocks[2])\n",
    "load_up_block(unet.up_block2, params.up_blocks[3])\n",
    "\n",
    "#out\n",
    "unet.out[0].load_state_dict(params.conv_norm_out.state_dict())\n",
    "unet.out[2].load_state_dict(params.conv_out.state_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(True)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out_vae = torch.randn(1, 4, 64, 64)\n",
    "out_encoder = torch.randn(1, 77, 768)\n",
    "time = torch.LongTensor([26])\n",
    "\n",
    "a = unet(out_vae=out_vae, out_encoder=out_encoder, time=time)\n",
    "b = params(out_vae, time, out_encoder).sample\n",
    "\n",
    "(a == b).all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "simple",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
